/*
 * Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "cutlass/arch/mma.h"
#include "cutlass/gemm/collective/builders/sm90_common.inl"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"

// SM90 Collective Builders should be used only starting CUDA 12.0
#if (__CUDACC_VER_MAJOR__ >= 12)
#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
#endif

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail {

// Returns the maximum number of smem tiles that can be used with a given smem capacity, or
// overrides with manual count.
template <int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK, bool SwapAB,
          int carveout_bytes>
constexpr int compute_stage_count_or_override_gated(
    StageCountAutoCarveout<carveout_bytes> stage_count) {
  // Fixed per-stage overhead (mainloop pipeline barriers etc.).
  constexpr int stage_barrier_bytes = 32;
  constexpr int bits_a = static_cast<int>(sizeof_bits<ElementA>::value);
  constexpr int bits_b = static_cast<int>(sizeof_bits<ElementB>::value);
  // One operand is staged twice per pipeline stage (factor of 2 — presumably the
  // gated value/gate pair); SwapAB selects whether that doubled operand is A or B.
  // Each term is divided by 8 separately to keep the exact integer truncation.
  constexpr int a_stage_bits =
      bits_a * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * (SwapAB ? 2 : 1);
  constexpr int b_stage_bits =
      bits_b * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * (SwapAB ? 1 : 2);
  constexpr int stage_bytes = a_stage_bits / 8 + b_stage_bits / 8 + stage_barrier_bytes;
  // Fit as many stages as possible into the capacity left after the carveout.
  return (CapacityBytes - carveout_bytes) / stage_bytes;
}

}  // namespace detail

/////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA_TMA_WS_SS
// Collective builder for gated GEMM mainloops on SM90 where both operands are staged
// through shared memory (SS) and loaded via TMA, for the warp-specialized kernel
// schedules listed in the enable_if below. `Activation` and `SwapAB` are forwarded to
// the gated collective mainloop.
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB,
          int AlignmentB, class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
          class StageCountType, class KernelScheduleType,
          template <class /* ElementCompute */> class Activation, bool SwapAB>
struct CollectiveBuilderGated<
    arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA, ElementB, GmemLayoutB,
    AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
    KernelScheduleType, Activation, SwapAB,
    cute::enable_if_t<
        (cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
         cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
         cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> ||
         cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>) &&
        not detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>()>> {
  static_assert(is_static<TileShape_MNK>::value);
  static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  static_assert(cutlass::detail::dependent_false<ElementA>,
                "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
  static_assert(
      detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
      "Should meet TMA alignment requirement\n");

  static constexpr bool IsArrayOfPointersGemm =
      (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>);
  static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
  // FP8 inputs must go through the FastAccum builder when using the ptr-array schedule.
  static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
                "Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 "
                "FastAccum version right now\n");

  // For fp32 types, map to tf32 MMA value type
  using MmaElementA = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
  using MmaElementB = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;

  // Translate the layout tags into GMMA majorness for each operand.
  static constexpr cute::GMMA::Major GmmaMajorA =
      detail::gmma_ss_tag_to_major_A<MmaElementA, GmemLayoutA>();
  static constexpr cute::GMMA::Major GmmaMajorB =
      detail::gmma_ss_tag_to_major_B<MmaElementB, GmemLayoutB>();

  // Cooperative / ptr-array schedules use two MMA atoms along M; others use one.
  using AtomLayoutMNK = cute::conditional_t<
      cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> ||
          IsArrayOfPointersGemm,
      Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(
      cute::GMMA::ss_op_selector<MmaElementA, MmaElementB, ElementAccumulator, TileShape_MNK,
                                 GmmaMajorA, GmmaMajorB>(),
      AtomLayoutMNK{}));

  // TMA atom selection (multicast or not) depends on the cluster extent of the
  // *other* operand's broadcast dimension.
  using GmemTiledCopyA =
      decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  using GmemTiledCopyB =
      decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

  // Smem layout atoms sized from the (M,K) / (N,K) extents of the CTA tile.
  using SmemLayoutAtomA =
      decltype(detail::ss_smem_selector<GmmaMajorA, MmaElementA,
                                        decltype(cute::get<0>(TileShape_MNK{})),
                                        decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutAtomB =
      decltype(detail::ss_smem_selector<GmmaMajorB, MmaElementB,
                                        decltype(cute::get<1>(TileShape_MNK{})),
                                        decltype(cute::get<2>(TileShape_MNK{}))>());

  // Stage count computed from remaining smem capacity; accounts for the doubled
  // gated operand (selected by SwapAB) via the gated stage-count helper.
  static constexpr int PipelineStages =
      detail::compute_stage_count_or_override_gated<detail::sm90_smem_capacity_bytes, MmaElementA,
                                                    MmaElementB, TileShape_MNK, SwapAB>(
          StageCountType{});
  using DispatchPolicy = cute::conditional_t<
      IsArrayOfPointersGemm,
      MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
      /* For FP8 use a separate mainloop compared to other datatypes */
      cute::conditional_t<IsFP8Input,
                          MainloopSm90TmaGmmaWarpSpecializedFP8<PipelineStages, ClusterShape_MNK,
                                                                KernelScheduleType>,
                          MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK,
                                                             KernelScheduleType>>>;

  // SS kernels copy operands directly from smem into the MMA; no register-stage copy atoms.
  using SmemCopyAtomA = void;
  using SmemCopyAtomB = void;

  using CollectiveOp =
      CollectiveMmaGated<DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
                         ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA,
                         SmemLayoutAtomA, SmemCopyAtomA, cute::identity, GmemTiledCopyB,
                         SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA_TMA_WS_FP8_FAST_ACCUM_SS
// Collective builder for gated GEMM mainloops on SM90 with FP8 inputs and fast
// accumulation: both operands staged through shared memory (SS), loaded via TMA,
// for the FP8FastAccum warp-specialized kernel schedules. `Activation` and `SwapAB`
// are forwarded to the gated collective mainloop.
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB,
          int AlignmentB, class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
          class StageCountType, class KernelScheduleType,
          template <class /* ElementCompute */> class Activation, bool SwapAB>
struct CollectiveBuilderGated<
    arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA, ElementB, GmemLayoutB,
    AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
    KernelScheduleType, Activation, SwapAB,
    cute::enable_if_t<
        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccum> ||
        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpongFP8FastAccum> ||
        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
        cute::is_same_v<KernelScheduleType,
                        KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>>> {
  static_assert(is_static<TileShape_MNK>::value);
  static_assert(is_static<ClusterShape_MNK>::value);
  // Diagnostic wording kept consistent with the non-FP8 SS builder above.
  static_assert(
      detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
      "Should meet TMA alignment requirement\n");
  static_assert(detail::is_input_fp8<ElementA, ElementB>(),
                "Only FP8 datatypes are compatible with these kernel schedules\n");
  // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
  static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>(),
                "Not supported for fp8 non-TN warp specialized kernels yet\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  static_assert(cutlass::detail::dependent_false<ElementA>,
                "Unsupported Toolkit for SM90 Collective Builder\n");
#endif

  // Translate the layout tags into GMMA majorness for each operand.
  static constexpr cute::GMMA::Major GmmaMajorA =
      detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutA>();
  static constexpr cute::GMMA::Major GmmaMajorB =
      detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutB>();

  static constexpr bool IsArrayOfPointersGemm =
      (cute::is_same_v<KernelScheduleType,
                       KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>);
  // Cooperative / ptr-array schedules use two MMA atoms along M; others use one.
  using AtomLayoutMNK = cute::conditional_t<
      cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
          IsArrayOfPointersGemm,
      Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(
      cute::GMMA::ss_op_selector<ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA,
                                 GmmaMajorB>(),
      AtomLayoutMNK{}));

  // TMA atom selection (multicast or not) depends on the cluster extent of the
  // *other* operand's broadcast dimension.
  using GmemTiledCopyA =
      decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  using GmemTiledCopyB =
      decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

  // Smem layout atoms sized from the (M,K) / (N,K) extents of the CTA tile.
  using SmemLayoutAtomA =
      decltype(detail::ss_smem_selector<GmmaMajorA, ElementA,
                                        decltype(cute::get<0>(TileShape_MNK{})),
                                        decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutAtomB =
      decltype(detail::ss_smem_selector<GmmaMajorB, ElementB,
                                        decltype(cute::get<1>(TileShape_MNK{})),
                                        decltype(cute::get<2>(TileShape_MNK{}))>());

  // Stage count computed from remaining smem capacity; accounts for the doubled
  // gated operand (selected by SwapAB) via the gated stage-count helper.
  static constexpr int PipelineStages =
      detail::compute_stage_count_or_override_gated<detail::sm90_smem_capacity_bytes, ElementA,
                                                    ElementB, TileShape_MNK, SwapAB>(
          StageCountType{});
  // Fast-accum FP8 uses the generic warp-specialized mainloop (no FP8-specific
  // accumulation promotion), with the array variant for ptr-array schedules.
  using DispatchPolicy = cute::conditional_t<
      IsArrayOfPointersGemm,
      MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
      MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;

  // SS kernels copy operands directly from smem into the MMA; no register-stage copy atoms.
  using SmemCopyAtomA = void;
  using SmemCopyAtomB = void;

  using CollectiveOp =
      CollectiveMmaGated<DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
                         ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA,
                         SmemLayoutAtomA, SmemCopyAtomA, cute::identity, GmemTiledCopyB,
                         SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////
